{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we'll show how to export SqueezeNet, which is implemented and trained in PyTorch, to run on mobile devices.\n",
    "Before we start, you should have [pytorch](https://github.com/pytorch/pytorch), [caffe2](https://github.com/caffe2/caffe2), and [onnx](https://github.com/onnx/onnx) installed in your environment and have cloned [AICamera](https://github.com/bwasti/AICamera).\n",
    "Please check out their GitHub pages for installation instructions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Some standard imports\n",
    "import io\n",
    "import numpy as np\n",
    "import torch.onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following implementation of squeezenet is from [torchvision](https://github.com/pytorch/vision/blob/master/torchvision/models/squeezenet.py)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.init as init\n",
    "import torch.utils.model_zoo as model_zoo\n",
    "\n",
    "\n",
    "# Public names kept from the torchvision module this cell was copied from.\n",
    "# NOTE(review): 'squeezenet1_0' is listed here, but only squeezenet1_1 is defined in this cell.\n",
    "__all__ = ['SqueezeNet', 'squeezenet1_0', 'squeezenet1_1']\n",
    "\n",
    "\n",
    "# Download locations of the pretrained checkpoints. squeezenet1_1() below loads the\n",
    "# 'squeezenet1_1' entry through model_zoo.load_url.\n",
    "model_urls = {\n",
    "    'squeezenet1_0': 'https://download.pytorch.org/models/squeezenet1_0-a815701f.pth',\n",
    "    'squeezenet1_1': 'https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth',\n",
    "}\n",
    "\n",
    "\n",
    "class Fire(nn.Module):\n",
    "    \"\"\"SqueezeNet Fire block.\n",
    "\n",
    "    A 1x1 ``squeeze`` convolution is followed by two parallel ``expand``\n",
    "    convolutions (1x1, and 3x3 with padding 1), each with its own in-place\n",
    "    ReLU. The two expand outputs are concatenated along dimension 1.\n",
    "\n",
    "    Args:\n",
    "        inplanes: number of input channels.\n",
    "        squeeze_planes: output channels of the squeeze convolution.\n",
    "        expand1x1_planes: output channels of the 1x1 expand convolution.\n",
    "        expand3x3_planes: output channels of the 3x3 expand convolution.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, inplanes, squeeze_planes,\n",
    "                 expand1x1_planes, expand3x3_planes):\n",
    "        super(Fire, self).__init__()\n",
    "        self.inplanes = inplanes\n",
    "\n",
    "        # Squeeze stage\n",
    "        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)\n",
    "        self.squeeze_activation = nn.ReLU(inplace=True)\n",
    "\n",
    "        # Expand stage, 1x1 branch\n",
    "        self.expand1x1 = nn.Conv2d(\n",
    "            squeeze_planes, expand1x1_planes, kernel_size=1)\n",
    "        self.expand1x1_activation = nn.ReLU(inplace=True)\n",
    "\n",
    "        # Expand stage, 3x3 branch (padding keeps the spatial size)\n",
    "        self.expand3x3 = nn.Conv2d(\n",
    "            squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)\n",
    "        self.expand3x3_activation = nn.ReLU(inplace=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Squeeze ``x``, expand it through both branches and concatenate them.\"\"\"\n",
    "        squeezed = self.squeeze_activation(self.squeeze(x))\n",
    "        branch1x1 = self.expand1x1_activation(self.expand1x1(squeezed))\n",
    "        branch3x3 = self.expand3x3_activation(self.expand3x3(squeezed))\n",
    "        return torch.cat([branch1x1, branch3x3], 1)\n",
    "\n",
    "\n",
    "class SqueezeNet(nn.Module):\n",
    "    \"\"\"SqueezeNet image classifier (versions 1.0 and 1.1).\n",
    "\n",
    "    ``features`` is a stack of convolution, max-pool and Fire layers that ends\n",
    "    with 512 channels. ``classifier`` applies dropout, a 1x1 convolution to\n",
    "    ``num_classes`` channels, ReLU and a 13x13 average pool.\n",
    "\n",
    "    Args:\n",
    "        version: architecture variant, either 1.0 or 1.1.\n",
    "        num_classes: number of output classes.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``version`` is neither 1.0 nor 1.1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, version=1.0, num_classes=1000):\n",
    "        super(SqueezeNet, self).__init__()\n",
    "        if version not in [1.0, 1.1]:\n",
    "            # The two literals are concatenated, so the first one must end\n",
    "            # with a space to keep the message readable.\n",
    "            raise ValueError(\"Unsupported SqueezeNet version {version}: \"\n",
    "                             \"1.0 or 1.1 expected\".format(version=version))\n",
    "        self.num_classes = num_classes\n",
    "        if version == 1.0:\n",
    "            self.features = nn.Sequential(\n",
    "                nn.Conv2d(3, 96, kernel_size=7, stride=2),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),\n",
    "                Fire(96, 16, 64, 64),\n",
    "                Fire(128, 16, 64, 64),\n",
    "                Fire(128, 32, 128, 128),\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),\n",
    "                Fire(256, 32, 128, 128),\n",
    "                Fire(256, 48, 192, 192),\n",
    "                Fire(384, 48, 192, 192),\n",
    "                Fire(384, 64, 256, 256),\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),\n",
    "                Fire(512, 64, 256, 256),\n",
    "            )\n",
    "        else:\n",
    "            self.features = nn.Sequential(\n",
    "                nn.Conv2d(3, 64, kernel_size=3, stride=2),\n",
    "                nn.ReLU(inplace=True),\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),\n",
    "                Fire(64, 16, 64, 64),\n",
    "                Fire(128, 16, 64, 64),\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),\n",
    "                Fire(128, 32, 128, 128),\n",
    "                Fire(256, 32, 128, 128),\n",
    "                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=False),\n",
    "                Fire(256, 48, 192, 192),\n",
    "                Fire(384, 48, 192, 192),\n",
    "                Fire(384, 64, 256, 256),\n",
    "                Fire(512, 64, 256, 256),\n",
    "            )\n",
    "        # Final convolution is initialized differently from the rest\n",
    "        final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Dropout(p=0.5),\n",
    "            final_conv,\n",
    "            nn.ReLU(inplace=True),\n",
    "            nn.AvgPool2d(13)\n",
    "        )\n",
    "\n",
    "        # Weight init: normal(0, 0.01) for the final convolution, Kaiming\n",
    "        # uniform for every other convolution; all convolution biases are zeroed.\n",
    "        for m in self.modules():\n",
    "            if isinstance(m, nn.Conv2d):\n",
    "                if m is final_conv:\n",
    "                    init.normal(m.weight.data, mean=0.0, std=0.01)\n",
    "                else:\n",
    "                    init.kaiming_uniform(m.weight.data)\n",
    "                if m.bias is not None:\n",
    "                    m.bias.data.zero_()\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return class scores reshaped to ``(batch, num_classes)``.\n",
    "\n",
    "        The reshape only succeeds when the pooled feature map is 1x1; this\n",
    "        notebook feeds 224x224 images.\n",
    "        \"\"\"\n",
    "        x = self.features(x)\n",
    "        x = self.classifier(x)\n",
    "        return x.view(x.size(0), self.num_classes)\n",
    "\n",
    "\n",
    "def squeezenet1_1(pretrained=False, **kwargs):\n",
    "    r\"\"\"Build a SqueezeNet 1.1 model, as published in the `official SqueezeNet repo\n",
    "    <https://github.com/DeepScale/SqueezeNet/tree/master/SqueezeNet_v1.1>`_.\n",
    "\n",
    "    SqueezeNet 1.1 has 2.4x less computation and slightly fewer parameters\n",
    "    than SqueezeNet 1.0, without sacrificing accuracy.\n",
    "\n",
    "    Args:\n",
    "        pretrained (bool): If True, returns a model pre-trained on ImageNet\n",
    "        **kwargs: forwarded to the ``SqueezeNet`` constructor.\n",
    "    \"\"\"\n",
    "    net = SqueezeNet(version=1.1, **kwargs)\n",
    "    if not pretrained:\n",
    "        return net\n",
    "    # Fetch the published checkpoint and copy its weights into the network.\n",
    "    state_dict = model_zoo.load_url(model_urls['squeezenet1_1'])\n",
    "    net.load_state_dict(state_dict)\n",
    "    return net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can get the torch model by calling the following function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Get pretrained squeezenet model\n",
    "torch_model = squeezenet1_1(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "and export the PyTorch model as an ONNX model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from torch.autograd import Variable\n",
    "batch_size = 1    # batch dimension of the dummy input below\n",
    "\n",
    "# Random dummy input to the model (batch_size x 3 x 224 x 224)\n",
    "x = Variable(torch.randn(batch_size, 3, 224, 224), requires_grad=True)\n",
    "\n",
    "# Export the model. The returned torch_out is kept so that it can be compared\n",
    "# with the Caffe2 result in the next code cell.\n",
    "torch_out = torch.onnx._export(torch_model,             # model being run\n",
    "                               x,                       # model input (or a tuple for multiple inputs)\n",
    "                               \"squeezenet.onnx\",       # where to save the model (can be a file or file-like object)\n",
    "                               export_params=True)      # store the trained parameter weights inside the model file\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After that, we can prepare and run the model and verify that the result of the model running on PyTorch matches the result running on the onnx-caffe2 backend."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import onnx\n",
    "import caffe2.python.onnx.backend\n",
    "from onnx import helper\n",
    "\n",
    "# Load the exported file. onnx.load returns a ModelProto (a standard Python protobuf\n",
    "# object); the network itself is stored in model.graph.\n",
    "model = onnx.load(\"squeezenet.onnx\")\n",
    "\n",
    "# Prepare the Caffe2 backend for executing the model. This converts the ONNX graph into a\n",
    "# Caffe2 NetDef that can execute it. Other ONNX backends, like one for CNTK, will be\n",
    "# available soon.\n",
    "prepared_backend = caffe2.python.onnx.backend.prepare(model)\n",
    "\n",
    "# Run the model in Caffe2.\n",
    "\n",
    "# Construct a map from input names to tensor data.\n",
    "# The graph's inputs include all weight parameters in addition to the input image.\n",
    "# Since the weights are already embedded, we just need to pass the input image,\n",
    "# which is bound here to the first graph input.\n",
    "W = {model.graph.input[0].name: x.data.numpy()}\n",
    "\n",
    "# Run the Caffe2 net; [0] selects its first output.\n",
    "c2_out = prepared_backend.run(W)[0]\n",
    "\n",
    "# Verify that the PyTorch and Caffe2 outputs agree to 3 decimal places.\n",
    "np.testing.assert_almost_equal(torch_out.data.cpu().numpy(), c2_out, decimal=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can export the model to run on mobile devices, leveraging the cross-platform capability of caffe2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export to mobile: convert the ONNX model into the two Caffe2 nets and save\n",
    "# each one as a serialized protobuf file.\n",
    "from caffe2.python.onnx.backend import Caffe2Backend as c2\n",
    "\n",
    "init_net, predict_net = c2.onnx_graph_to_caffe2_net(model)\n",
    "for pb_path, net_def in ((\"squeeze_init_net.pb\", init_net),\n",
    "                         (\"squeeze_predict_net.pb\", predict_net)):\n",
    "    with open(pb_path, \"wb\") as f:\n",
    "        f.write(net_def.SerializeToString())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You'll see `squeeze_init_net.pb` and `squeeze_predict_net.pb` in the same directory as this notebook. Let's make sure the model can run with the predictor, since that's what we'll use in the mobile app."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 1000)\n"
     ]
    }
   ],
   "source": [
    "# Verify it runs with predictor\n",
    "from caffe2.python import workspace\n",
    "\n",
    "# init_net / predict_net are rebound here from protobuf objects to the raw\n",
    "# bytes read back from disk.\n",
    "with open(\"squeeze_init_net.pb\", \"rb\") as f:\n",
    "    init_net = f.read()\n",
    "with open(\"squeeze_predict_net.pb\", \"rb\") as f:\n",
    "    predict_net = f.read()\n",
    "p = workspace.Predictor(init_net, predict_net)\n",
    "\n",
    "# Smoke test: a random float32 image batch should run through the predictor.\n",
    "img = np.random.rand(1, 3, 224, 224).astype(np.float32)\n",
    "[result] = p.run([img])\n",
    "print(result.shape) # our model produces prediction for each of ImageNet 1000 classes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Putting the models into the Android app\n",
    "\n",
    "After we are sure that it runs with the predictor, we can copy `squeeze_init_net.pb` and `squeeze_predict_net.pb` to\n",
    "`AICamera/app/src/main/assets`.\n",
    "\n",
    "Now we can open Android Studio, import the AICamera project, and run the app by clicking the green play button.\n",
    "\n",
    "You can check [Caffe2 AI Camera tutorial](https://caffe2.ai/docs/AI-Camera-demo-android.html) for more details of how Caffe2 can be invoked in the Android mobile app."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
